# --------------------------------------------------------
# modified from Hora

import os
import time
import torch

from hora.algo.ppo.experience import ExperienceBuffer
from hora.algo.models.models import ActorCritic, ActorCriticMaskCompensator, ActorCriticAsymmetricBC
from hora.algo.models.running_mean_std import RunningMeanStd

from hora.utils.misc import AverageScalarMeter, Average1DTensorMeter

from tensorboardX import SummaryWriter


def batched_index_select(values, indices, dim=1):
    """Select entries of ``values`` along ``dim`` using batched ``indices``.

    Dimensions ``0 .. dim-1`` are shared batch dimensions of ``values`` and
    ``indices``. Dimension ``dim`` of ``values`` is indexed by the entries of
    ``indices``. Trailing dimensions of ``values`` (after ``dim``) are carried
    along unchanged. Extra dimensions of ``indices`` (after ``dim``) are
    broadcast over ``values``.

    Example: ``values`` of shape ``(B, N, D)`` and ``indices`` of shape
    ``(B, K)`` give a ``(B, K, D)`` result with
    ``out[b, k] == values[b, indices[b, k]]``.

    Args:
        values: Source tensor.
        indices: int64 index tensor (as required by ``torch.gather``) whose
            leading ``dim`` dimensions should match those of ``values``.
        dim: Dimension of ``values`` to select along.

    Returns:
        Tensor of shape ``indices.shape + values.shape[dim + 1:]``.
    """
    trailing_shape = tuple(values.shape[dim + 1:])
    index_rank = indices.dim()

    # Give the index tensor the trailing value dimensions so that gather
    # copies whole sub-tensors rather than single elements.
    index_view = indices[(..., *([None] * len(trailing_shape)))]
    index_view = index_view.expand(*([-1] * index_rank), *trailing_shape)

    # Index dimensions beyond ``dim`` that ``values`` lacks: insert singleton
    # axes right after the batch dims, then broadcast them to the index sizes.
    extra_dims = index_rank - (dim + 1)
    source = values[(*([slice(None)] * dim), *([None] * extra_dims), ...)]
    grown = slice(dim, dim + extra_dims)
    target_shape = [-1] * source.dim()
    target_shape[grown] = index_view.shape[grown]
    source = source.expand(*target_shape)

    # The selection axis has shifted right by the number of inserted axes.
    return source.gather(dim + extra_dims, index_view)




class PPO(object):
    def __init__(self, env, output_dif, full_config):
        """Build the PPO agent: networks, normalizers, optimizer, buffers and loggers.

        Depending on flags in ``full_config.train.ppo`` this also:

        * loads frozen base policies (``trainActionCompensatorWRealWM``),
        * builds a BC base policy (``useBCBasePolicy``),
        * loads pretrained per-joint compensators (``hierarchicalCompensator``),
        * loads pretrained BC actor/critic weights (``tuneBCModel``).

        Args:
            env: Environment. Besides ``action_space`` / ``observation_space``,
                several compensator-related attributes are read from it
                (e.g. ``train_action_compensator``,
                ``compensator_output_joint_idxes``).
            output_dif: Output directory. ``stage1_nn`` (checkpoints) and
                ``stage1_tb`` (TensorBoard logs) are created inside it.
            full_config: Experiment config. Reads ``rl_device``,
                ``train.network``, ``train.ppo`` and ``test``.
        """
        self.device = full_config['rl_device']
        self.network_config = full_config.train.network
        self.ppo_config = full_config.train.ppo
        # ---- build environment ----
        self.env = env
        self.num_actors = self.ppo_config['num_actors']
        action_space = self.env.action_space
        self.actions_num = action_space.shape[0]
        self.actions_low = torch.from_numpy(action_space.low.copy()).float().to(self.device)
        self.actions_high = torch.from_numpy(action_space.high.copy()).float().to(self.device)
        self.observation_space = self.env.observation_space
        self.obs_shape = self.observation_space.shape
        # ---- Priv Info ----
        self.priv_info_dim = self.ppo_config['priv_info_dim']
        self.priv_info = self.ppo_config['priv_info']
        self.proprio_adapt = self.ppo_config['proprio_adapt']

        # ---- Action compensator with real world model ----
        # When enabled, the trainable model consumes compensator observations,
        # and frozen base policies are loaded further below.
        self.train_action_compensator_w_real_wm = self.ppo_config.get('trainActionCompensatorWRealWM', False)
        self.policy_model_ckpt_fn = self.ppo_config.get('policyModelCheckpoint', '')  # base policy checkpoint; 'ANDOBJ'-separated list when multiBasePolicy
        self.action_compensator_obs_dim = self.ppo_config.get('actionCompensatorObsDim', 112)  # overridden below when the masked compensator is used
        self.use_bc_base_policy = self.ppo_config.get('useBCBasePolicy', False)
        self.bc_base_policy_ckpt_fn = self.ppo_config.get('bcBasePolicyCheckpoint', '')
        self.hierarchical_compensator = self.ppo_config.get('hierarchicalCompensator', False)
        self.multi_base_policy = self.ppo_config.get('multiBasePolicy', False)
        # The following compensator mode flags are read from the environment.
        self.train_action_compensator = self.env.train_action_compensator
        self.per_joint_action_compensator = self.env.per_joint_action_compensator
        self.train_action_compensator_uan = self.env.train_action_compensator_uan

        # ---- Tune BC Model ----
        self.tune_bc_model = self.ppo_config.get('tuneBCModel', False)
        self.bc_model_history_length = self.ppo_config.get('bcModelHistoryLength', 10)
        self.tune_bc_via_compensator_model = self.ppo_config.get('tuneBCviaCompensatorModel', False)
        # (16 + 16) values per history step; presumably 16 joint positions and
        # 16 joint targets -- TODO confirm.
        self.actor_model_input_dim = (16 + 16) * self.bc_model_history_length
        self.critic_model_input_dim = (16 + 16) * self.bc_model_history_length

        # Joints whose actions the compensator model outputs.
        self.compensator_output_joint_idxes = self.env.compensator_output_joint_idxes
        # print(f"compensator_output_joint_idxes: {self.compensator_output_joint_idxes}")
        self.use_masked_action_compensator = self.env.use_masked_action_compensator


        # try:
        #     self.compensator_considering_joint_idxes = self.env.compensator_considering_joint_idxes
        # except:
        #     self.compensator_considering_joint_idxes = [ self.compensator_output_joint_idxes[0].item() ]

        # if len(self.compensator_considering_joint_idxes) > 1:
        #     self.joint_idx_to_model = {}
        #     self.joint_idx_to_running_mean_std = {}
        #     self.joint_idx_to_value_mean_std = {}
        #     for cur_joint_idx in self.compensator_considering_joint_idxes:
        #         cur_joint_net_config = {
        #             'actor_units': self.network_config.mlp.units,
        #             'priv_mlp_units': self.network_config.priv_mlp.units,
        #             'actions_num': 1 , 
        #             'input_shape': (self.action_compensator_obs_dim, ),
        #             'priv_info': False,
        #             'proprio_adapt': self.proprio_adapt,
        #             'priv_info_dim': self.priv_info_dim,
        #         }
        #         self.joint_idx_to_model[cur_joint_idx] = ActorCritic(cur_joint_net_config)
        #         self.joint_idx_to_model[cur_joint_idx].to(self.device)
        #         self.joint_idx_to_running_mean_std[cur_joint_idx] = RunningMeanStd((self.action_compensator_obs_dim, )).to(self.device)
        #         self.joint_idx_to_value_mean_std[cur_joint_idx] = RunningMeanStd((1,)).to(self.device)
        # else:
        # ---- Model ----

        if self.train_action_compensator and self.per_joint_action_compensator:
            # Per-joint compensator: one action from a small fixed-width input
            # (21 wide with the "uan" variant, otherwise 2).
            compensator_input_dim = 21 if self.train_action_compensator_uan else 2
            # self.obs_shape = (compensator_input_dim, )
            net_config = {
                'actor_units': self.network_config.mlp.units,
                'priv_mlp_units': self.network_config.priv_mlp.units,
                'actions_num': 1 ,
                'input_shape': (compensator_input_dim, ),
                'priv_info': False,
                'proprio_adapt': self.proprio_adapt,
                'priv_info_dim': self.priv_info_dim,
            }
        else:
            # Standard policy, or (with trainActionCompensatorWRealWM) a
            # compensator acting on compensator_output_joint_idxes from
            # action_compensator_obs_dim-wide inputs without privileged info.
            net_config = {
                'actor_units': self.network_config.mlp.units,
                'priv_mlp_units': self.network_config.priv_mlp.units,
                'actions_num': self.actions_num if (not self.train_action_compensator_w_real_wm) else self.compensator_output_joint_idxes.shape[0],
                'input_shape': self.obs_shape if (not self.train_action_compensator_w_real_wm) else (self.action_compensator_obs_dim, ),
                'priv_info': self.priv_info if (not self.train_action_compensator_w_real_wm) else False,
                'proprio_adapt': self.proprio_adapt,
                'priv_info_dim': self.priv_info_dim,
            }

        if self.tune_bc_model:
            # Fine-tune a pretrained BC actor; the env's joint limits are handed to the model.
            net_config['tune_bc_via_compensator_model'] = self.tune_bc_via_compensator_model
            self.model = ActorCriticAsymmetricBC(net_config)
            self.model.actor_model_input_dim = self.actor_model_input_dim
            self.model.critic_model_input_dim = self.critic_model_input_dim
            self.model.allegro_hand_dof_lower_limits  = self.env.allegro_hand_dof_lower_limits
            self.model.allegro_hand_dof_upper_limits  = self.env.allegro_hand_dof_upper_limits
        elif self.use_masked_action_compensator:
            net_config['target_joint_idx_tensor'] = self.compensator_output_joint_idxes
            self.model = ActorCriticMaskCompensator(net_config)
            # NOTE(review): this overrides action_compensator_obs_dim *after*
            # net_config['input_shape'] was built from the config value above;
            # the normalizer and experience buffer below use this new width.
            # Confirm actionCompensatorObsDim is configured consistently.
            self.action_compensator_obs_dim = 32 * self.env.compensator_history_length + 16
        else:
            self.model = ActorCritic(net_config)
        self.model.to(self.device)


        # The observation normalizer is constructed identically in both branches.
        # NOTE(review): in per-joint mode the model input is compensator_input_dim
        # wide, but this normalizer is sized by obs_shape / action_compensator_obs_dim
        # -- confirm this is intended.
        if self.train_action_compensator and self.per_joint_action_compensator:
            # obs_shape = (compensator_input_dim, )
            # self.running_mean_std = RunningMeanStd(obs_shape).to(self.device)

            self.model.train_action_compensator = self.train_action_compensator
            self.model.per_joint_action_compensator = self.per_joint_action_compensator
            self.model.train_action_compensator_uan = self.train_action_compensator_uan

            self.running_mean_std = RunningMeanStd(self.obs_shape if (not self.train_action_compensator_w_real_wm) else (self.action_compensator_obs_dim, )).to(self.device)
        else:
            self.model.train_action_compensator = False
            self.model.per_joint_action_compensator = False
            self.model.train_action_compensator_uan = False

            self.running_mean_std = RunningMeanStd(self.obs_shape if (not self.train_action_compensator_w_real_wm) else (self.action_compensator_obs_dim, )).to(self.device)

        self.value_mean_std = RunningMeanStd((1,)).to(self.device)


        # ---- Load Policy ----
        # Frozen (eval-mode) base policies whose outputs feed the compensator.
        if self.train_action_compensator_w_real_wm:

            if self.multi_base_policy:
                # One ActorCritic + normalizer per 'ANDOBJ'-separated checkpoint path,
                # keyed by its position in that list.
                self.policy_idx_to_model_ckpt_fn = self.policy_model_ckpt_fn.split('ANDOBJ')
                self.policy_idx_to_model = {}
                self.policy_idx_to_running_mean_std = {}
                self.sorted_policy_idxes = list(sorted(range(len(self.policy_idx_to_model_ckpt_fn))))
                for cur_policy_idx in self.sorted_policy_idxes:
                    cur_ckpt_fn = self.policy_idx_to_model_ckpt_fn[cur_policy_idx]
                    policy_net_config = {
                        'actor_units': self.network_config.mlp.units,
                        'priv_mlp_units': self.network_config.priv_mlp.units,
                        'actions_num': self.actions_num,
                        'input_shape': self.obs_shape,
                        'priv_info': self.priv_info,
                        'proprio_adapt': self.proprio_adapt,
                        'priv_info_dim': self.priv_info_dim,
                    }
                    self.policy_idx_to_model[cur_policy_idx] = ActorCritic(policy_net_config)
                    self.policy_idx_to_model[cur_policy_idx].to(self.device)
                    self.policy_idx_to_running_mean_std[cur_policy_idx] = RunningMeanStd(self.obs_shape).to(self.device)

                    policy_ckpt = torch.load(cur_ckpt_fn)
                    self.policy_idx_to_model[cur_policy_idx].load_state_dict(policy_ckpt['model'])
                    self.policy_idx_to_running_mean_std[cur_policy_idx].load_state_dict(policy_ckpt['running_mean_std'])

                    self.policy_idx_to_model[cur_policy_idx].eval()
                    self.policy_idx_to_running_mean_std[cur_policy_idx].eval()
            else:
                policy_net_config = {
                    'actor_units': self.network_config.mlp.units,
                    'priv_mlp_units': self.network_config.priv_mlp.units,
                    'actions_num': self.actions_num,
                    'input_shape': self.obs_shape,
                    'priv_info': self.priv_info,
                    'proprio_adapt': self.proprio_adapt,
                    'priv_info_dim': self.priv_info_dim,
                }
                self.policy_model = ActorCritic(policy_net_config)
                self.policy_model.to(self.device)
                self.policy_model_running_mean_std = RunningMeanStd(self.obs_shape).to(self.device)
                self.policy_model.eval()
                self.policy_model_running_mean_std.eval()
                # load the single base policy checkpoint
                policy_ckpt = torch.load(self.policy_model_ckpt_fn)
                self.policy_model.load_state_dict(policy_ckpt['model'])
                self.policy_model_running_mean_std.load_state_dict(policy_ckpt['running_mean_std'])




            if self.use_bc_base_policy:
                # NOTE(review): yaml, dict2namespace and DiffusionControlSeq are not
                # imported at the top of this file; this branch raises NameError
                # unless they are provided elsewhere -- confirm. The config path
                # below is relative to the current working directory.
                self.invdyn_v2_config_path = 'controlseq.yml'
                with open(os.path.join("../IsaacGymEnvs2/isaacgymenvs/ddim/configs", self.invdyn_v2_config_path), "r") as f:
                    config = yaml.safe_load(f)
                invdyn_config = dict2namespace(config)
                invdyn_config.device = self.device
                invdyn_config.invdyn.model_arch = 'resmlp'
                invdyn_config.invdyn.res_blocks = 2
                invdyn_config.invdyn.pred_extrin = False

                self.bc_history_length = 10
                # self.history_length = 4

                invdyn_config.invdyn.history_length = self.bc_history_length
                invdyn_config.invdyn.future_length = 2
                invdyn_config.invdyn.res_blocks = 5  # overrides the value 2 set above

                invdyn_config.invdyn.future_ref_dim = 3
                invdyn_config.invdyn.pred_extrin = False

                # Minimal stand-in for the command-line args object expected by
                # DiffusionControlSeq.
                class dummy:
                    def __init__(self):
                        self.log_path = ''
                        self.sample_type = 'generalized'
                        self.skip_type = 'uniform'
                        self.timesteps = 50
                        self.eta = 0
                        self.model_type = 'invdyn'
                        # self.optimize_via_fingertip_pos 


                invdyn_args = dummy()
                invdyn_args.log_path = self.bc_base_policy_ckpt_fn

                self.bc_policy = DiffusionControlSeq(invdyn_args, invdyn_config)
                self.bc_policy.init_models(ckpt_fn=None)

            if self.hierarchical_compensator:
                # Pretrained first-level per-joint compensators (joints 8 and 12);
                # checkpoint paths are hard-coded and relative to the working directory.
                self.joint_idx_to_compensator_ckpt_fn = {
                    8: 'outputs/LeapHora/debug_true_cuboidthin_tunecompensatorviarl_finger3_jt8_jtsinputoutput_woresidual_wm2_scl1d24_v10/stage1_nn/best.pth',
                    12: 'outputs/LeapHora/debug_true_cuboidthin_tunecompensatorviarl_finger3_jt12_jtsinputoutput_woresidual_wm2_scl1d24_v10/stage1_nn/best.pth'
                }
                self.joint_idx_to_compensator = {}
                self.joint_idx_to_compensator_running_mean_std = {}
                # Equals 3 when env.wm_history_length == 1.
                compensator_input_dim = self.env.wm_history_length * 2  + 1
                for joint_idx in self.joint_idx_to_compensator_ckpt_fn:
                    # NOTE(review): cur_ckpt_fn is unused; the path is looked up again below.
                    cur_ckpt_fn = self.joint_idx_to_compensator_ckpt_fn[joint_idx]
                    compensator_obs_shape = (compensator_input_dim, )
                    policy_net_config = {
                        'actions_num': 1, 
                        'input_shape': compensator_obs_shape,
                        'actor_units': [512, 256, 128],
                        'priv_mlp_units': [256, 128, 8],
                        'priv_info': False,
                        'proprio_adapt': False,
                        'priv_info_dim': 9,
                    }
                    self.joint_idx_to_compensator[joint_idx] = ActorCritic(policy_net_config)
                    self.joint_idx_to_compensator[joint_idx].to(self.device)
                    self.joint_idx_to_compensator_running_mean_std[joint_idx] = RunningMeanStd(compensator_obs_shape).to(self.device)

                    cur_action_compensator_ckpt = torch.load(self.joint_idx_to_compensator_ckpt_fn[joint_idx])
                    self.joint_idx_to_compensator[joint_idx].load_state_dict(cur_action_compensator_ckpt['model'])
                    self.joint_idx_to_compensator_running_mean_std[joint_idx].load_state_dict(cur_action_compensator_ckpt['running_mean_std'])

                    self.joint_idx_to_compensator[joint_idx].eval()
                    self.joint_idx_to_compensator_running_mean_std[joint_idx].eval()



        if self.tune_bc_model:
            actor_ckpt_fn = self.ppo_config.get('actorModelCheckpoint', '')
            self.model.load_actor_models(actor_ckpt_fn)
            # Replaces the normalizer built above. NOTE(review): 357 presumably
            # equals self.model.original_obs_dim (the slice normalized in
            # model_act_for_bc_model) -- confirm.
            self.running_mean_std = RunningMeanStd((357, )).to(self.device)

            if not self.tune_bc_via_compensator_model:
                # Reuse the value network and both normalizers from the original policy.
                critic_ckpt_fn = self.ppo_config.get('criticModelCheckpoint', '')
                self.model.load_critic_models(critic_ckpt_fn)
                ori_policy_ckpt_fn = self.ppo_config.get('oriPolicyModelCheckpoint', '')
                policy_ckpt = torch.load(ori_policy_ckpt_fn)
                self.model.load_value_actor_mlp_and_value(ori_policy_ckpt_fn)

                self.running_mean_std.load_state_dict(policy_ckpt['running_mean_std'])
                self.value_mean_std.load_state_dict(policy_ckpt['value_mean_std'])



        # ---- Output Dir ----
        # allows us to specify a folder where all experiments will reside
        self.output_dir = output_dif
        self.nn_dir = os.path.join(self.output_dir, 'stage1_nn')
        self.tb_dif = os.path.join(self.output_dir, 'stage1_tb')
        os.makedirs(self.nn_dir, exist_ok=True)
        os.makedirs(self.tb_dif, exist_ok=True)
        # ---- Optim ----
        self.last_lr = float(self.ppo_config['learning_rate'])
        self.weight_decay = self.ppo_config.get('weight_decay', 0.0)

        if self.tune_bc_model:
            # Only a subset of parameters is optimized: via the compensator model,
            # everything except 'actor_mlp'; otherwise, everything except the
            # value network loaded from the original policy.
            if self.tune_bc_via_compensator_model:
                tot_parameters = []
                for name, param in self.model.named_parameters():
                    if 'actor_mlp' not in name:
                        tot_parameters.append(param)
            else:
                tot_parameters = []
                for name, param in self.model.named_parameters():
                    if 'value_actor_mlp.' not in name and 'value.' not in name:
                        tot_parameters.append(param)
            self.optimizer = torch.optim.Adam(tot_parameters, self.last_lr, weight_decay=self.weight_decay)

            if not self.tune_bc_via_compensator_model:
                # Keep the loaded normalizer statistics (put in eval mode).
                self.running_mean_std.eval()
                self.value_mean_std.eval()
        else:
            self.optimizer = torch.optim.Adam(self.model.parameters(), self.last_lr, weight_decay=self.weight_decay)


        # ---- PPO Train Param ----
        self.e_clip = self.ppo_config['e_clip']
        self.clip_value = self.ppo_config['clip_value']
        self.entropy_coef = self.ppo_config['entropy_coef']
        self.critic_coef = self.ppo_config['critic_coef']
        self.bounds_loss_coef = self.ppo_config['bounds_loss_coef']
        self.gamma = self.ppo_config['gamma']
        self.tau = self.ppo_config['tau']
        self.truncate_grads = self.ppo_config['truncate_grads']
        self.grad_norm = self.ppo_config['grad_norm']
        self.value_bootstrap = self.ppo_config['value_bootstrap']
        self.normalize_advantage = self.ppo_config['normalize_advantage']
        self.normalize_input = self.ppo_config['normalize_input']
        self.normalize_value = self.ppo_config['normalize_value']
        # ---- PPO Collect Param ----
        self.horizon_length = self.ppo_config['horizon_length']
        self.batch_size = self.horizon_length * self.num_actors
        self.minibatch_size = self.ppo_config['minibatch_size']
        self.mini_epochs_num = self.ppo_config['mini_epochs']
        assert self.batch_size % self.minibatch_size == 0 or full_config.test
        # ---- scheduler ----
        self.kl_threshold = self.ppo_config['kl_threshold']
        # NOTE(review): AdaptiveScheduler is not imported at the top of this
        # file; presumably defined later in this module -- confirm.
        self.scheduler = AdaptiveScheduler(self.kl_threshold)
        # ---- Snapshot
        self.save_freq = self.ppo_config['save_frequency']
        self.save_best_after = self.ppo_config['save_best_after']
        # ---- Tensorboard Logger ----
        self.extra_info = {}
        writer = SummaryWriter(self.tb_dif)
        self.writer = writer

        # Rolling episode statistics (window sizes as given).
        self.episode_rewards = AverageScalarMeter(100)
        self.episode_lengths = AverageScalarMeter(100)
        self.episode_rewards_pose_guidance = AverageScalarMeter(100)
        self.episode_rewards_pose_guidance_bonus = AverageScalarMeter(100)
        self.episode_rewards_wo_bonus = AverageScalarMeter(100)
        self.episode_obj_rot_vel = Average1DTensorMeter(window_size=20000, tensor_size=3)
        self.obs = None
        self.epoch_num = 0

        # Rollout storage: compensator observations/actions in compensator mode,
        # raw env observations/actions otherwise.
        if self.train_action_compensator_w_real_wm:
            self.storage = ExperienceBuffer(
                self.num_actors, self.horizon_length, self.batch_size, self.minibatch_size, self.action_compensator_obs_dim,
                self.actions_num if (not self.train_action_compensator_w_real_wm) else self.compensator_output_joint_idxes.shape[0] , self.priv_info_dim, self.device,
            )
        else:
            self.storage = ExperienceBuffer(
                self.num_actors, self.horizon_length, self.batch_size, self.minibatch_size, self.obs_shape[0],
                self.actions_num, self.priv_info_dim, self.device,
            )
            if self.tune_bc_model:
                # Extra (horizon, actors, original_obs_dim) buffer for the original observations.
                self.storage.storage_dict['ori_obs'] = torch.zeros((self.horizon_length, self.num_actors, self.model.original_obs_dim), dtype=torch.float32, device=self.device)
                self.storage.tune_bc_model = True

        # Per-actor running episode accumulators.
        batch_size = self.num_actors
        current_rewards_shape = (batch_size, 1)
        self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.device)
        self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.device)
        self.current_rewards_pose_guidance = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.device)
        self.current_rewards_pose_guidance_bonus = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.device)
        self.current_rewards_wo_bonus = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.device)

        self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.device)
        self.agent_steps = 0
        self.max_agent_steps = self.ppo_config['max_agent_steps']
        self.best_rewards = -10000
        # ---- Timing
        self.data_collect_time = 0
        self.rl_train_time = 0
        self.all_time = 0

    def write_stats(self, a_losses, c_losses, b_losses, entropies, kls):
        """Log throughput, loss and optimizer statistics to TensorBoard.

        Each loss argument is a list of scalar tensors (one per update); the
        mean of each list is logged. Every entry of ``self.extra_info`` is then
        logged under its own key. All scalars use ``self.agent_steps`` as step.
        """
        step = self.agent_steps
        add = self.writer.add_scalar

        def mean_of(tensors):
            return torch.mean(torch.stack(tensors)).item()

        add('performance/RLTrainFPS', step / self.rl_train_time, step)
        add('performance/EnvStepFPS', step / self.data_collect_time, step)

        loss_series = (
            ('losses/actor_loss', a_losses),
            ('losses/bounds_loss', b_losses),
            ('losses/critic_loss', c_losses),
            ('losses/entropy', entropies),
        )
        for tag, series in loss_series:
            add(tag, mean_of(series), step)

        add('info/last_lr', self.last_lr, step)
        add('info/e_clip', self.e_clip, step)
        add('info/kl', mean_of(kls), step)

        for key, value in self.extra_info.items():
            add(f'{key}', value, step)

    def set_eval(self):
        """Switch the policy, plus whichever normalizers are enabled, to eval mode."""
        to_eval = [self.model]
        if self.normalize_input:
            to_eval.append(self.running_mean_std)
        if self.normalize_value:
            to_eval.append(self.value_mean_std)
        for module in to_eval:
            module.eval()

    def set_train(self):
        """Switch the policy, plus whichever normalizers are enabled, to train mode."""
        self.model.train()
        for enabled, attr_name in (
            (self.normalize_input, 'running_mean_std'),
            (self.normalize_value, 'value_mean_std'),
        ):
            if enabled:
                getattr(self, attr_name).train()


    def model_act(self, obs_dict):
        """Query the trainable model for actions and values for one env step.

        Dispatches to ``model_act_for_bc_model`` when tuning a BC model, or to
        ``model_act_for_action_compensator`` in compensator mode. Otherwise
        it normalizes ``obs_dict['obs']`` with ``running_mean_std``, runs
        ``self.model.act``, and de-normalizes ``'values'`` with
        ``value_mean_std``.
        """
        if self.tune_bc_model:
            return self.model_act_for_bc_model(obs_dict)
        if self.train_action_compensator_w_real_wm:
            return self.model_act_for_action_compensator(obs_dict)

        outputs = self.model.act({
            'obs': self.running_mean_std(obs_dict['obs']),
            'priv_info': obs_dict['priv_info'],
        })
        outputs['values'] = self.value_mean_std(outputs['values'], True)
        return outputs
    
    def model_act_for_action_compensator(self, obs_dict):
        """Run the trainable action compensator on top of a frozen base controller.

        Steps:
          1. Compute base joint targets ``cur_targets``. They come either from
             the BC base policy (``use_bc_base_policy``) or from the frozen RL
             base policy/policies, as
             ``env.prev_targets + 1/24 * clamp(mu, -1, 1)``.
          2. In masked mode with ``hierarchical_compensator``, refine the
             targets of selected joints with pretrained per-joint compensators.
          3. Append the targets to ``obs_dict['compensator_obs']``: all joints
             in masked mode, otherwise only ``env.compensator_input_joint_idxes``.
             Then normalize with ``running_mean_std`` and call
             ``self.model.act``.

        Side effects:
            Writes ``env.bc_policy_pred_targets`` (BC branch) or
            ``self.policy_act`` (RL base branch). In hierarchical mode also
            writes ``env.first_level_compensated_targets``. Always writes
            ``self.compensator_obs``, the un-normalized model input.

        Returns:
            The dict from ``self.model.act``, with ``'values'`` de-normalized
            via ``value_mean_std``.
        """


        if self.use_bc_base_policy:
            # Last bc_history_length entries along the final axis (presumably
            # time -- confirm), flattened to (dim0, -1).
            history_qpos = self.env.obs_buf_lag_history_qpos[..., -self.bc_history_length: ]
            history_qtars = self.env.obs_buf_lag_history_qtars[..., -self.bc_history_length: ]
            flatten_history_qpos = history_qpos.view(history_qpos.size(0), -1).contiguous()
            flatten_history_qtars = history_qtars.view(history_qtars.size(0), -1).contiguous()

            history_input = torch.cat([flatten_history_qpos, flatten_history_qtars], dim=-1)
            # All-zero future reference. Width 6 presumably equals
            # future_length * future_ref_dim (2 * 3) as configured in __init__ -- confirm.
            future_input = torch.zeros((flatten_history_qpos.size(0), 6), dtype=torch.float32, device=self.device)
            pred_target = self.bc_policy.forward_states_for_actions(history_input, future_input, history_extrin=None, hist_context=None)
            cur_targets = pred_target[..., :16]  # first 16 outputs used as joint targets
            self.env.bc_policy_pred_targets = cur_targets
            # NOTE(review): self.policy_act is not updated on this branch.
        else:
            if self.multi_base_policy:
                # Evaluate every base policy on the full batch, stack along
                # dim 1, then pick each row's action via env.envs_policy_idx.
                policy_obs = obs_dict['obs']
                tot_mus = []
                for cur_policy_idx in self.sorted_policy_idxes:
                    cur_policy_obs = self.policy_idx_to_running_mean_std[cur_policy_idx](policy_obs)
                    cur_policy_input_dict = {
                        'obs': cur_policy_obs, 
                        'priv_info': obs_dict['priv_info']
                    }
                    cur_policy_mu = self.policy_idx_to_model[cur_policy_idx].act_inference(cur_policy_input_dict)
                    cur_policy_mu = torch.clamp(cur_policy_mu, -1.0, 1.0)
                    tot_mus.append(cur_policy_mu.clone())
                tot_mus = torch.stack(tot_mus, dim=1)
                envs_policy_idx = self.env.envs_policy_idx
                mu = batched_index_select(tot_mus, envs_policy_idx.unsqueeze(1), dim=1).squeeze(1)
            else:
                policy_obs = obs_dict['obs']
                policy_obs = self.policy_model_running_mean_std(policy_obs)
                policy_input_dict = {
                    'obs': policy_obs, 
                    'priv_info': obs_dict['priv_info']
                }

                mu = self.policy_model.act_inference(policy_input_dict)
                mu = torch.clamp(mu, -1.0, 1.0)  # clamp the deterministic action to [-1, 1]

            self.policy_act = mu.clone()
            # Base targets: previous targets plus the base action scaled by 1/24.
            cur_targets = self.env.prev_targets + 1/24 * mu
        compensator_obs = obs_dict['compensator_obs']
        if self.use_masked_action_compensator:
            if self.hierarchical_compensator:
                # First-level compensators, also scaled by 1/24.
                # compensator_obs is split into all-but-last-16 columns and the last 16 columns.
                # NOTE(review): for each joint, a single column is taken from
                # each part plus the joint's target, so the per-joint input is
                # 3 wide. That matches compensator_input_dim in __init__ only when
                # env.wm_history_length == 1 -- confirm the intended layout.

                compensator_hist_obs, compensator_hist_actions = compensator_obs[..., : -16], compensator_obs[..., -16: ]
                first_level_compensated_targets = cur_targets.clone()
                for joint_idx in self.joint_idx_to_compensator:
                    cur_joint_hist_obs = compensator_hist_obs[..., joint_idx: joint_idx + 1]
                    cur_joint_hist_actions = compensator_hist_actions[..., joint_idx: joint_idx + 1]
                    # cur_joint_hist_obs = cur_joint_hist_obs.contiguous().view(cur_joint_hist_obs.shape[0], -1).contiguous()
                    # cur_joint_hist_actions = cur_joint_hist_actions.contiguous().view(cur_joint_hist_actions.shape[0], -1).contiguous()
                    cur_joint_compensator_input = torch.cat([cur_joint_hist_obs, cur_joint_hist_actions, cur_targets[..., joint_idx: joint_idx + 1]], dim=-1)
                    cur_joint_compensator_input = self.joint_idx_to_compensator_running_mean_std[joint_idx](cur_joint_compensator_input)
                    cur_joint_compensator_input_dict = {
                        'obs': cur_joint_compensator_input,
                    }
                    first_level_delta_action = self.joint_idx_to_compensator[joint_idx].act_inference(cur_joint_compensator_input_dict)
                    first_level_delta_action = torch.clamp(first_level_delta_action, -1.0, 1.0)
                    first_level_compensated_targets[..., joint_idx: joint_idx + 1] = cur_targets[..., joint_idx: joint_idx + 1] + 1/24 * first_level_delta_action
                cur_targets = first_level_compensated_targets
                # Exposed to the env (presumably consumed in its pre-physics
                # step together with this model's output actions -- confirm).
                self.env.first_level_compensated_targets = first_level_compensated_targets.clone()

            compensator_obs  = torch.cat([compensator_obs, cur_targets], dim=1)
        else:
            compensator_obs  = torch.cat([compensator_obs, cur_targets[..., self.env.compensator_input_joint_idxes]], dim=1)
        # Keep the un-normalized model input.
        self.compensator_obs = compensator_obs.clone() 
        compensator_obs = self.running_mean_std(compensator_obs)
        compensator_input_dict = {
            'obs': compensator_obs,
            'priv_info': obs_dict['priv_info']
        }
        # delta_actions = self.model(compensator_input_dict)
        res_dict = self.model.act(compensator_input_dict)
        res_dict['values'] = self.value_mean_std(res_dict['values'], True)
        return res_dict
    
    def model_act_for_bc_model(self, obs_dict):
        """Query the asymmetric BC model for actions and values.

        ``obs_dict['obs']`` is split at ``self.actor_model_input_dim`` into an
        actor part and a critic part. The first ``self.model.original_obs_dim``
        columns of ``env.original_obs_buf`` are normalized with
        ``running_mean_std`` and passed as ``'ori_obs'``. Predicted values are
        de-normalized with ``value_mean_std``. When
        ``tune_bc_via_compensator_model`` is set, the model's
        ``bc_model_actions`` are detached and published on
        ``env.bc_model_actions``.
        """
        full_obs = obs_dict['obs']
        split_at = self.actor_model_input_dim
        actor_part = full_obs[..., :split_at]
        critic_part = full_obs[..., split_at:]

        original_slice = self.env.original_obs_buf[..., :self.model.original_obs_dim]
        model_inputs = {
            'ori_obs': self.running_mean_std(original_slice),
            'actor_obs': actor_part,
            'critic_obs': critic_part,
            'priv_info': obs_dict['priv_info'],
        }
        outputs = self.model.act(model_inputs)
        outputs['values'] = self.value_mean_std(outputs['values'], True)

        if self.tune_bc_via_compensator_model:
            self.env.bc_model_actions = self.model.bc_model_actions.detach()

        return outputs
    
    
    def train(self):
        """Run the PPO training loop until ``max_agent_steps`` is reached.

        Each iteration does the following:

        * calls ``train_epoch`` (rollout collection + updates);
        * prints a progress line and logs losses and episode statistics to
          TensorBoard;
        * writes checkpoints: every ``save_freq`` epochs a named snapshot plus
          ``last``, and ``best`` / ``best_<reward>`` whenever the rolling mean
          episode reward improves (from epoch ``save_best_after`` on).
        """
        # perf_counter is monotonic and high-resolution, which suits measuring
        # elapsed time (time.time() can jump with wall-clock adjustments).
        start_t = time.perf_counter()
        last_t = start_t
        self.obs = self.env.reset()
        self.agent_steps = self.batch_size

        while self.agent_steps < self.max_agent_steps:
            self.epoch_num += 1
            a_losses, c_losses, b_losses, entropies, kls = self.train_epoch()
            self.storage.data_dict = None

            # Read the episode meters *after* the epoch so the progress line
            # reflects the epoch that just finished, not the one before it.
            mean_rewards = self.episode_rewards.get_mean()
            mean_lengths = self.episode_lengths.get_mean()
            mean_rewards_pose_guidance = self.episode_rewards_pose_guidance.get_mean()
            mean_rewards_pose_guidance_bonus = self.episode_rewards_pose_guidance_bonus.get_mean()
            mean_rewards_wo_bonus = self.episode_rewards_wo_bonus.get_mean()
            mean_obj_rot_vel = self.episode_obj_rot_vel.get_mean()

            now = time.perf_counter()
            # Guard against a zero elapsed time on very coarse clocks.
            all_fps = self.agent_steps / max(now - start_t, 1e-9)
            last_fps = self.batch_size / max(now - last_t, 1e-9)
            last_t = now
            info_string = f'Agent Steps: {int(self.agent_steps // 1e6):04}M | FPS: {all_fps:.1f} | ' \
                          f'Last FPS: {last_fps:.1f} | ' \
                          f'Collect Time: {self.data_collect_time / 60:.1f} min | ' \
                          f'Train RL Time: {self.rl_train_time / 60:.1f} min | ' \
                          f'Current Best: {self.best_rewards:.2f} | ' \
                          f'Current Mean: {mean_rewards:.2f} | ' \
                          f'Current Mean w/o Bonus: {mean_rewards_wo_bonus:.2f} | ' \
                          f'Current Mean Pose Guidance: {mean_rewards_pose_guidance:.2f} | ' \
                          f'Current Mean Pose Guidance Bonus: {mean_rewards_pose_guidance_bonus:.2f} | ' \
                          f'Avg Rot-Vel: {mean_obj_rot_vel}'
            print(info_string)

            self.write_stats(a_losses, c_losses, b_losses, entropies, kls)

            self.writer.add_scalar('episode_rewards/step', mean_rewards, self.agent_steps)
            self.writer.add_scalar('episode_lengths/step', mean_lengths, self.agent_steps)
            checkpoint_name = f'ep_{self.epoch_num}_step_{int(self.agent_steps // 1e6):04}M_reward_{mean_rewards:.2f}'

            if self.save_freq > 0 and self.epoch_num % self.save_freq == 0:
                self.save(os.path.join(self.nn_dir, checkpoint_name))
                self.save(os.path.join(self.nn_dir, 'last'))

            if mean_rewards > self.best_rewards and self.epoch_num >= self.save_best_after:
                print(f'save current best reward: {mean_rewards:.2f}')
                self.best_rewards = mean_rewards
                self.save(os.path.join(self.nn_dir, 'best'))
                self.save(os.path.join(self.nn_dir, f'best_{mean_rewards}'))

        print('max steps achieved')

    def save(self, name):
        """Serialize the model and normalizer states to ``f'{name}.pth'``.

        Args:
            name: Output path without the ``.pth`` extension.

        The saved dict always holds ``'model'``. ``'running_mean_std'`` and
        ``'value_mean_std'`` are added only when the corresponding attribute is
        not ``None``.
        """
        weights = {'model': self.model.state_dict()}
        # Explicit ``is not None`` checks: module truthiness would silently drop
        # a normalizer implemented as a container whose ``len()`` is 0.
        if self.running_mean_std is not None:
            weights['running_mean_std'] = self.running_mean_std.state_dict()
        if self.value_mean_std is not None:
            weights['value_mean_std'] = self.value_mean_std.state_dict()
        torch.save(weights, f'{name}.pth')



    def restore_train(self, fn):
        """Load model and normalizer weights from checkpoint ``fn`` to resume training.

        Args:
            fn: Path to a ``.pth`` file written by ``save``. A falsy value
                (``None`` / ``''``) makes this a no-op.
        """
        if not fn:
            return

        # Map tensors onto the training device so checkpoints written on another
        # GPU (or on a CUDA machine when training on CPU) still load.
        checkpoint = torch.load(fn, map_location=self.device)
        print(f"Loading  from {fn}")

        self.model.load_state_dict(checkpoint['model'])
        # ``save`` only writes normalizer entries when they exist, so a missing
        # entry is reported instead of raising KeyError.
        for key in ('running_mean_std', 'value_mean_std'):
            if key in checkpoint:
                getattr(self, key).load_state_dict(checkpoint[key])
            else:
                print(f"Warning: '{key}' not found in {fn}; keeping current statistics")

    def restore_test(self, fn):
        """Load weights from checkpoint ``fn`` for evaluation.

        Restores the model and, when ``self.normalize_input`` is set, the
        observation normalizer (``running_mean_std``). ``value_mean_std`` is not
        restored here.

        Args:
            fn: Path to a ``.pth`` file written by ``save``.
        """
        # Map tensors onto the evaluation device so checkpoints written on a
        # different GPU / machine load without device errors.
        checkpoint = torch.load(fn, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        if self.normalize_input:
            self.running_mean_std.load_state_dict(checkpoint['running_mean_std'])

    def test(self):
        """Roll out the current policy with deterministic actions; never returns.

        Each step builds the model input (normalized observation or, in
        ``tune_bc_model`` mode, the normalized original observation), adds the
        raw actor/critic observation split, calls ``self.model.act_inference``,
        forwards auxiliary model outputs to the env, clamps the action to
        ``[-1, 1]`` and steps the env.
        """
        self.set_eval()
        obs_dict = self.env.reset()
        while True:
            if self.tune_bc_model:
                input_dict = {
                    'ori_obs': self.running_mean_std(self.env.original_obs_buf[..., : self.model.original_obs_dim]),
                    'priv_info': obs_dict['priv_info'],
                }
            else:
                input_dict = {
                    'obs': self.running_mean_std(obs_dict['obs']),
                    'priv_info': obs_dict['priv_info'],
                }

            # Split the raw (un-normalized) observation at actor_model_input_dim:
            # the leading part feeds the actor, the remainder the critic.
            actor_obs = obs_dict['obs'][..., : self.actor_model_input_dim]
            critic_obs = obs_dict['obs'][..., self.actor_model_input_dim:]
            input_dict.update({
                'actor_obs': actor_obs,
                'critic_obs': critic_obs,
            })

            mu = self.model.act_inference(input_dict)

            if self.tune_bc_model and self.tune_bc_via_compensator_model:
                self.env.bc_model_actions = self.model.bc_model_actions.detach()

            # Expose the model's value estimates to the env.
            self.env.value_vals = self.model.value_vals
            mu = torch.clamp(mu, -1.0, 1.0)
            # Not every model exposes a usable ``extrin``; only that case is
            # skipped. The previous bare ``except`` also swallowed
            # KeyboardInterrupt and genuine errors inside this endless loop.
            try:
                self.env.extrin = self.model.extrin.clone()
            except AttributeError:
                pass
            obs_dict, _, _, _ = self.env.step(mu)

    def train_epoch(self):
        """Collect one rollout, then run ``mini_epochs_num`` PPO passes over it.

        Returns:
            Tuple ``(a_losses, c_losses, b_losses, entropies, kls)``. The first
            four are lists with one scalar tensor per minibatch (actor, critic and
            bounds losses, and entropy). ``kls`` holds one mean KL per mini-epoch,
            which also drives the adaptive learning-rate scheduler.
        """
        # collect minibatch data
        _t = time.time()
        self.set_eval()
        self.play_steps()
        self.data_collect_time += (time.time() - _t)
        # update network
        _t = time.time()
        self.set_train()
        a_losses, b_losses, c_losses = [], [], []
        entropies, kls = [], []
        for _ in range(0, self.mini_epochs_num):
            ep_kls = []
            for i in range(len(self.storage)):
                if self.tune_bc_model:
                    # BC-tuning storage additionally yields the original observation,
                    # which is the one that gets normalized.
                    value_preds, old_action_log_probs, advantage, old_mu, old_sigma, \
                        returns, actions, obs, priv_info, ori_obs = self.storage[i]
                    batch_dict = {
                        'ori_obs': self.running_mean_std(ori_obs),
                        'prev_actions': actions,
                        'obs': obs,
                        'priv_info': priv_info,
                    }
                else:
                    value_preds, old_action_log_probs, advantage, old_mu, old_sigma, \
                        returns, actions, obs, priv_info = self.storage[i]
                    batch_dict = {
                        'prev_actions': actions,
                        'obs': self.running_mean_std(obs),
                        'priv_info': priv_info,
                    }

                res_dict = self.model(batch_dict)
                action_log_probs = res_dict['prev_neglogp']
                values = res_dict['values']
                entropy = res_dict['entropy']
                mu = res_dict['mus']
                sigma = res_dict['sigmas']

                # actor loss: clipped surrogate. The log-probs are negative
                # log-likelihoods, hence exp(old - new) = pi_new / pi_old.
                ratio = torch.exp(old_action_log_probs - action_log_probs)
                surr1 = advantage * ratio
                surr2 = advantage * torch.clamp(ratio, 1.0 - self.e_clip, 1.0 + self.e_clip)
                a_loss = torch.max(-surr1, -surr2)
                # critic loss: clipped value loss
                value_pred_clipped = value_preds + (values - value_preds).clamp(-self.e_clip, self.e_clip)
                value_losses = (values - returns) ** 2
                value_losses_clipped = (value_pred_clipped - returns) ** 2
                c_loss = torch.max(value_losses, value_losses_clipped)
                # bounds loss: quadratic penalty only on action means outside
                # [-soft_bound, soft_bound]. The old clamp_max/clamp_max form summed
                # to (mu - soft_bound)**2 for every mu, pulling all means toward +1.1.
                if self.bounds_loss_coef > 0:
                    soft_bound = 1.1
                    mu_loss_high = torch.clamp_min(mu - soft_bound, 0.0) ** 2
                    mu_loss_low = torch.clamp_max(mu + soft_bound, 0.0) ** 2
                    b_loss = (mu_loss_low + mu_loss_high).sum(axis=-1)
                else:
                    # A tensor (not the int 0) so torch.mean below does not raise.
                    b_loss = torch.zeros((), device=mu.device)
                a_loss, c_loss, entropy, b_loss = [torch.mean(loss) for loss in [a_loss, c_loss, entropy, b_loss]]

                loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef

                self.optimizer.zero_grad()
                loss.backward()
                if self.truncate_grads:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
                self.optimizer.step()

                with torch.no_grad():
                    kl = policy_kl(mu.detach(), sigma.detach(), old_mu, old_sigma)

                a_losses.append(a_loss)
                c_losses.append(c_loss)
                ep_kls.append(kl)
                entropies.append(entropy)
                if self.bounds_loss_coef is not None:
                    b_losses.append(b_loss)

                # Refresh stored mu/sigma so the next KL compares against the latest policy.
                self.storage.update_mu_sigma(mu.detach(), sigma.detach())

            # Adapt the learning rate from this mini-epoch's mean KL.
            av_kls = torch.mean(torch.stack(ep_kls))
            self.last_lr = self.scheduler.update(self.last_lr, av_kls.item())
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.last_lr
            kls.append(av_kls)

        self.rl_train_time += (time.time() - _t)
        return a_losses, c_losses, b_losses, entropies, kls

    def play_steps(self):
        """Collect one rollout of ``self.horizon_length`` env steps into ``self.storage``.

        On every step this:

        * stores the current observation and ``priv_info``;
        * stores the outputs of ``model_act`` (actions, neglogpacs, values, mus, sigmas);
        * steps the env with actions clamped to ``[-1, 1]``;
        * stores dones and shaped rewards;
        * accumulates per-env episode sums, pushing them into the episode meters
          for envs that report done.

        After the loop it:

        * takes the values of the final observation;
        * asks the storage to compute returns;
        * advances ``self.agent_steps`` by ``self.batch_size``;
        * when ``self.normalize_value`` is set, normalizes the stored values and
          returns with ``self.value_mean_std``.
        """
        for n in range(self.horizon_length): # one iteration per env step
            res_dict = self.model_act(self.obs)
            # collect o_t
            # NOTE(review): self.compensator_obs and self.policy_act are not set in
            # this method; presumably model_act fills them when
            # train_action_compensator_w_real_wm is on -- confirm.
            if self.train_action_compensator_w_real_wm:
                self.storage.update_data('obses', n, self.compensator_obs)
            else:
                self.storage.update_data('obses', n, self.obs['obs'])
            self.storage.update_data('priv_info', n, self.obs['priv_info'])
            for k in ['actions', 'neglogpacs', 'values', 'mus', 'sigmas']:
                self.storage.update_data(k, n, res_dict[k])
            
            # BC-tuning also keeps the leading original_obs_dim entries of the env's
            # original observation buffer.
            if self.tune_bc_model:
                ori_obs = self.env.original_obs_buf[..., : self.model.original_obs_dim]
                self.storage.update_data('ori_obs', n, ori_obs)
            
            # do env step
            if self.train_action_compensator_w_real_wm:
                # Compensator mode: the env is stepped with the clamped policy action,
                # while the (clamped) compensator output is handed to the env as
                # compensating_targets. Note the stored 'actions' above were written
                # before this clamp.
                actions = torch.clamp(self.policy_act, -1.0, 1.0)
                res_dict['actions'] = torch.clamp(res_dict['actions'], -1.0, 1.0)
                self.env.compensating_targets = res_dict['actions'].clone()
                self.obs, rewards, self.dones, infos = self.env.step(actions)
            else:
                actions = torch.clamp(res_dict['actions'], -1.0, 1.0)
                self.obs, rewards, self.dones, infos = self.env.step(actions)
            # Add a trailing dim (assumes rewards is (num_envs,) -- confirm) so it
            # lines up with the stored values.
            rewards = rewards.unsqueeze(1)
            # update dones and rewards after env step
            self.storage.update_data('dones', n, self.dones)
            # The rewards stored for PPO are scaled by a fixed 0.01; the unscaled
            # rewards are still used for the episode statistics below.
            shaped_rewards = 0.01 * rewards.clone()
            # On time-outs (presumably truncated episodes) bootstrap by adding
            # gamma * value from this step's model_act output.
            if self.value_bootstrap and 'time_outs' in infos:
                shaped_rewards += self.gamma * res_dict['values'] * infos['time_outs'].unsqueeze(1).float()
            self.storage.update_data('rewards', n, shaped_rewards)

            # Per-env running sums. The env's rew_buf_* buffers are read after
            # env.step, presumably holding the reward components of the step just taken.
            self.current_rewards += rewards
            self.current_rewards_pose_guidance  += self.env.rew_buf_aux_pose_guidance.unsqueeze(1)
            self.current_rewards_pose_guidance_bonus += self.env.rew_buf_aux_pose_guidance_bonus.unsqueeze(1)
            self.current_rewards_wo_bonus += self.env.rew_buf_wo_aux.unsqueeze(1) 
            self.current_lengths += 1
            # Push the accumulated totals of envs that finished on this step into the meters.
            done_indices = self.dones.nonzero(as_tuple=False)
            self.episode_rewards.update(self.current_rewards[done_indices])
            self.episode_lengths.update(self.current_lengths[done_indices])
            self.episode_rewards_pose_guidance.update(self.current_rewards_pose_guidance[done_indices])
            self.episode_rewards_pose_guidance_bonus.update(self.current_rewards_pose_guidance_bonus[done_indices])
            self.episode_rewards_wo_bonus.update(self.current_rewards_wo_bonus[done_indices])
            # Object angular velocity is logged for every env on every step, not only on done.
            self.episode_obj_rot_vel.update(self.env.object_angvel.cpu())

            assert isinstance(infos, dict), 'Info Should be a Dict'
            # Keep only scalar entries of infos for logging. The dict is rebuilt on
            # every step, so after the loop it reflects the last step only.
            self.extra_info = {}
            for k, v in infos.items():
                # only log scalars
                if isinstance(v, float) or isinstance(v, int) or (isinstance(v, torch.Tensor) and len(v.shape) == 0):
                    self.extra_info[k] = v

            # Zero the running sums of envs that just finished.
            not_dones = 1.0 - self.dones.float()

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_rewards_pose_guidance = self.current_rewards_pose_guidance * not_dones.unsqueeze(1)
            self.current_rewards_pose_guidance_bonus = self.current_rewards_pose_guidance_bonus * not_dones.unsqueeze(1)
            self.current_rewards_wo_bonus = self.current_rewards_wo_bonus * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

        # Values of the observation after the last step, used to bootstrap the returns.
        res_dict = self.model_act(self.obs)
        last_values = res_dict['values']

        self.agent_steps += self.batch_size
        # Presumably GAE-style return/advantage computation with gamma/tau -- see the storage class.
        self.storage.computer_return(last_values, self.gamma, self.tau)
        self.storage.prepare_training()

        returns = self.storage.data_dict['returns']
        values = self.storage.data_dict['values']
        if self.normalize_value:
            # train() presumably lets the normalizer update its running statistics
            # on this batch while normalizing; eval() freezes them again.
            self.value_mean_std.train()
            values = self.value_mean_std(values)
            returns = self.value_mean_std(returns)
            self.value_mean_std.eval()
        self.storage.data_dict['values'] = values
        self.storage.data_dict['returns'] = returns


def policy_kl(p0_mu, p0_sigma, p1_mu, p1_sigma):
    """KL(p0 || p1) between two diagonal Gaussians, summed over the last dim and averaged.

    Per element this evaluates
    ``log(s1/s0) + (s0^2 + (m1 - m0)^2) / (2 s1^2) - 1/2``,
    with a 1e-5 guard inside the log and in the denominator. The result is
    summed over the last dimension (action dims) and then averaged over
    everything else, giving a 0-dim tensor.
    """
    log_ratio = torch.log(p1_sigma / p0_sigma + 1e-5)
    spread = (p0_sigma ** 2 + (p1_mu - p0_mu) ** 2) / (2.0 * (p1_sigma ** 2 + 1e-5))
    per_dim = log_ratio + spread + (-1.0 / 2.0)
    return per_dim.sum(dim=-1).mean()


# from https://github.com/leggedrobotics/rsl_rl/blob/master/rsl_rl/algorithms/ppo.py
class AdaptiveScheduler(object):
    """KL-driven adaptive learning-rate schedule.

    The learning rate is divided by 1.5 when the observed KL exceeds twice the
    target, and multiplied by 1.5 when it falls below half the target. Adjusted
    values are clamped to ``[min_lr, max_lr]`` = ``[1e-6, 1e-2]``; otherwise
    the rate is returned unchanged.
    """

    def __init__(self, kl_threshold=0.008):
        super().__init__()
        self.min_lr = 1e-6
        self.max_lr = 1e-2
        self.kl_threshold = kl_threshold

    def update(self, current_lr, kl_dist):
        """Return the learning rate to use after observing mean KL ``kl_dist``."""
        kl_low = kl_dist < (0.5 * self.kl_threshold)
        kl_high = kl_dist > (2.0 * self.kl_threshold)
        # The increase takes precedence when both tests hold (only possible with
        # a negative threshold).
        if kl_low:
            return min(current_lr * 1.5, self.max_lr)
        if kl_high:
            return max(current_lr / 1.5, self.min_lr)
        return current_lr
